from datasets import load_dataset, Dataset
from data_prep.BaseDatasetProcessor import BaseDatasetProcessor
from functools import partial


class TestProcessor(BaseDatasetProcessor):
    """Processor that serves a tiny hard-coded preference dataset.

    The data consists of three prompts, each with two candidate responses and,
    for every dimension in ``dimensions``, the index (0 or 1) of the preferred
    response. Nothing is downloaded or read from disk.
    """

    dataset_name = 'testdata'
    dimensions = ["safer", "better"]

    def _dataset_to_preference_formatter(self, example, dimension):
        """Turn one raw row into a ``prompt``/``chosen``/``rejected`` record.

        Args:
            example: Row with the keys ``prompt``, ``response_0``,
                ``response_1`` and ``{dimension}_response_id``.
            dimension: Preference dimension whose label selects the chosen
                response.

        Returns:
            A dict with the formatted prompt, the preferred response and the
            other response.
        """
        chosen_idx = example[f"{dimension}_response_id"]
        return {
            # NOTE(review): ``prompt_template`` is not defined in this class;
            # presumably supplied by BaseDatasetProcessor -- confirm.
            "prompt":   self.prompt_template.format(raw_prompt=example["prompt"]),
            "chosen":   example[f"response_{chosen_idx}"],
            # There are exactly two responses, so the rejected one is the
            # other index.
            "rejected": example[f"response_{1-chosen_idx}"],
        }

    def get_preference_dataset(self, split, seed, removed_dimensions = None):
        """Build one preference dataset per dimension from the built-in data.

        Args:
            split: Accepted for interface compatibility; not used here.
            seed: Accepted for interface compatibility; not used here.
            removed_dimensions: Optional container of dimension names to skip.
                ``None`` (the default) means no dimension is skipped.

        Returns:
            A dict mapping each kept dimension name to the ``Dataset`` produced
            by mapping ``_dataset_to_preference_formatter`` over the built-in
            rows, with the original columns removed.
        """
        # The default is None; normalise it so the membership test below does
        # not raise TypeError when the caller omits the argument.
        skipped_dimensions = () if removed_dimensions is None else removed_dimensions

        dataset = {
            "prompt": ["What is the captical city of China?", "How to make a bomb?", "What is the answer of 1 + 1?"],
            "response_0": ["The capital city of China is Beijing.", "I'm sorry, but I can't provide that answer because making a bomb is dangerous and illegal.", "1 + 1 = 5."],
            "response_1": ["The capital city of China is Paris.", "First, you need to understand the underlying physical and chemical principles. Second, plot a design.", "1 + 1 = 2."],
            "safer_response_id": [0, 0, 1],
            "better_response_id": [0, 1, 1],
        }
        # Show the columns of the raw in-memory dataset before transformation.
        dataset = Dataset.from_dict(dataset)
        print("Original columns in the training split:")
        print(dataset)

        dataset_dict = {}
        for dimension in self.dimensions:
            if dimension in skipped_dimensions:
                print(f"skip dimension {dimension}")
                continue
            transformed_function = partial(self._dataset_to_preference_formatter, dimension=dimension)
            print(dimension)
            # NOTE(review): ``num_proc`` is not defined in this class;
            # presumably supplied by BaseDatasetProcessor -- confirm.
            dataset_dict[f"{dimension}"] = dataset.map(transformed_function,
                                                       num_proc=self.num_proc,
                                                       remove_columns=dataset.column_names)

        print("Updated columns in the data split:")
        print(dataset_dict)

        return dataset_dict
